{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:46:41.585713Z",
     "start_time": "2019-06-15T17:46:40.892585Z"
    }
   },
   "outputs": [],
   "source": [
    "# -*- coding: utf-8 -*-\n",
    "import os, sys, glob, argparse, json\n",
    "import pandas as pd\n",
    "import numpy as np\n",
    "from tqdm import tqdm, tqdm_notebook\n",
    "\n",
    "import time, datetime\n",
    "import pdb, traceback\n",
    "\n",
    "import cv2\n",
    "# import imagehash\n",
    "from PIL import Image\n",
    "\n",
    "from sklearn.model_selection import train_test_split, StratifiedKFold\n",
    "\n",
    "import torch\n",
    "# Seed torch's global RNG with a fixed value.\n",
    "torch.manual_seed(0)\n",
    "# NOTE(review): deterministic=False together with benchmark=True lets cuDNN pick\n",
    "# its algorithms freely, so GPU runs are presumably not bit-for-bit repeatable\n",
    "# despite the fixed seed above - confirm against the PyTorch reproducibility notes.\n",
    "torch.backends.cudnn.deterministic = False\n",
    "torch.backends.cudnn.benchmark = True\n",
    "\n",
    "import torchvision.models as models\n",
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import torch.optim as optim\n",
    "from torch.autograd import Variable\n",
    "from torch.utils.data.dataset import Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:46:41.600841Z",
     "start_time": "2019-06-15T17:46:41.590890Z"
    }
   },
   "outputs": [],
   "source": [
    "import logging\n",
    "# Send log records of level DEBUG and above to the file example.log. Each record\n",
    "# is formatted as: timestamp - emitting file name[line:line number]: message.\n",
    "logging.basicConfig(level=logging.DEBUG, filename='example.log',\n",
    "                    format='%(asctime)s - %(filename)s[line:%(lineno)d]: %(message)s')\n",
    "\n",
    "class QRDataset(Dataset):\n",
    "    \"\"\"Images paired with a fixed-length alphanumeric label.\n",
    "\n",
    "    Every record of ``img_json`` is indexed with the keys ``'name'`` (image file\n",
    "    name, joined onto ``img_dir``) and ``'text'`` (the label string). The first\n",
    "    ``LABEL_LEN`` characters of the stripped label are encoded as positions in\n",
    "    ``ALPHABET``; any further characters are ignored.\n",
    "\n",
    "    Args:\n",
    "        img_json: Indexable, sized collection of records as described above.\n",
    "        transform: Optional callable applied to the opened PIL image.\n",
    "        img_dir: Directory the image names are resolved against.\n",
    "    \"\"\"\n",
    "\n",
    "    ALPHABET = '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'\n",
    "    LABEL_LEN = 10\n",
    "\n",
    "    def __init__(self, img_json, transform=None, img_dir='../data/data/'):\n",
    "        self.img_json = img_json\n",
    "        self.transform = transform\n",
    "        self.img_dir = img_dir\n",
    "\n",
    "    def __getitem__(self, index):\n",
    "        \"\"\"Return ``(img, label0, ..., label9)`` for the record at ``index``.\n",
    "\n",
    "        ``img`` is the opened image after ``transform`` (if one was given); each\n",
    "        label is a 0-dim ``torch.long`` tensor holding one character's index.\n",
    "\n",
    "        Raises:\n",
    "            ValueError: If the label is too short or has characters outside\n",
    "                ``ALPHABET``.\n",
    "        \"\"\"\n",
    "        record = self.img_json[index]\n",
    "\n",
    "        # NOTE(review): the image is handed to the transform in its native mode;\n",
    "        # add .convert('RGB') here if the data can contain non-RGB files.\n",
    "        img = Image.open(os.path.join(self.img_dir, record['name']))\n",
    "        if self.transform is not None:\n",
    "            img = self.transform(img)\n",
    "\n",
    "        return (img,) + self._encode_label(record['text'])\n",
    "\n",
    "    def __len__(self):\n",
    "        \"\"\"Number of records in ``img_json``.\"\"\"\n",
    "        return len(self.img_json)\n",
    "\n",
    "    def _encode_label(self, text):\n",
    "        \"\"\"Encode the first ``LABEL_LEN`` characters of ``text`` as index tensors.\"\"\"\n",
    "        chars = text.strip()\n",
    "        if len(chars) < self.LABEL_LEN:\n",
    "            # Raise ValueError rather than letting indexing raise IndexError:\n",
    "            # an IndexError escaping __getitem__ ends plain iteration silently.\n",
    "            raise ValueError('label %r has fewer than %d characters'\n",
    "                             % (text, self.LABEL_LEN))\n",
    "\n",
    "        indices = [self.char2idx(ch) for ch in chars[:self.LABEL_LEN]]\n",
    "        if min(indices) < 0:\n",
    "            # char2idx signals an unknown character with -1, which is not a\n",
    "            # valid class index, so fail here with a readable message.\n",
    "            raise ValueError('label %r contains characters outside %r'\n",
    "                             % (text, self.ALPHABET))\n",
    "\n",
    "        return tuple(torch.tensor(idx, dtype=torch.long) for idx in indices)\n",
    "\n",
    "    def char2idx(self, ch):\n",
    "        \"\"\"Return the position of ``ch`` in ``ALPHABET``, or -1 if it is absent.\"\"\"\n",
    "        return self.ALPHABET.find(ch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:46:50.454871Z",
     "start_time": "2019-06-15T17:46:50.436948Z"
    }
   },
   "outputs": [],
   "source": [
    "class RMB_Net(nn.Module):\n",
    "    \"\"\"ResNet-50 trunk feeding one linear classifier per character position.\n",
    "\n",
    "    Args:\n",
    "        num_chars: Number of character positions, i.e. classifier heads.\n",
    "        num_classes: Number of classes predicted by each head.\n",
    "        pretrained: Passed as the first positional argument to\n",
    "            ``models.resnet50``.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self, num_chars=10, num_classes=36, pretrained=True):\n",
    "        super(RMB_Net, self).__init__()\n",
    "        self.num_chars = num_chars\n",
    "\n",
    "        # Input width of every head; must equal the flattened trunk output.\n",
    "        feat_size = 2048\n",
    "        # Heads are registered as fc0, fc1, ... (before the trunk), so attribute\n",
    "        # names and state_dict keys match the hand-unrolled layout.\n",
    "        for pos in range(num_chars):\n",
    "            setattr(self, 'fc%d' % pos, nn.Linear(feat_size, num_classes))\n",
    "\n",
    "        backbone = models.resnet50(pretrained)\n",
    "        # Keep every child module of the ResNet except the last one.\n",
    "        self.resnet = torch.nn.Sequential(*(list(backbone.children())[:-1]))\n",
    "\n",
    "    def forward(self, img):\n",
    "        \"\"\"Return a tuple with one log-softmax score tensor per head.\"\"\"\n",
    "        feat = self.resnet(img)\n",
    "        # Flatten everything after the batch dimension.\n",
    "        feat = feat.reshape(feat.size(0), -1)\n",
    "\n",
    "        return tuple(\n",
    "            F.log_softmax(getattr(self, 'fc%d' % pos)(feat), dim=1)\n",
    "            for pos in range(self.num_chars)\n",
    "        )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:34:12.726682Z",
     "start_time": "2019-06-15T17:34:12.711839Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "3"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Sanity check of the lookup that QRDataset.char2idx relies on: '3' is at index 3.\n",
    "'0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ'.find('3')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:34:13.469609Z",
     "start_time": "2019-06-15T17:34:13.226586Z"
    }
   },
   "outputs": [],
   "source": [
    "# Scratch experiment: build a torchvision ResNet-18 (the positional True is\n",
    "# presumably pretrained=True - confirm) and re-wrap all of its child modules\n",
    "# except the last one in a Sequential.\n",
    "# NOTE(review): RMB_Net builds its own resnet50 trunk, and this cell's execution\n",
    "# count repeats a later cell's, so it looks like a leftover from an earlier\n",
    "# session - confirm nothing relies on this module-level `model` before removing.\n",
    "model = models.resnet18(True)\n",
    "model = torch.nn.Sequential(*(list(model.children())[:-1]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:46:53.129755Z",
     "start_time": "2019-06-15T17:46:52.759684Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Parse ../data/desc.json into data_json.\n",
    "# NOTE(review): presumably the list of {'name', 'text'} records that QRDataset\n",
    "# indexes - confirm against the cell that builds the datasets.\n",
    "with open('../data/desc.json') as up:\n",
    "    data_json = json.load(up)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:46:55.211463Z",
     "start_time": "2019-06-15T17:46:55.182008Z"
    }
   },
   "outputs": [],
   "source": [
    "def accuracy(outputs, targets):\n",
    "    \"\"\"Fraction of samples for which every head predicts its target.\n",
    "\n",
    "    Args:\n",
    "        outputs: Sequence of score tensors, one per head, indexed\n",
    "            ``[sample, class]``; the top-scoring class along dim 1 is the\n",
    "            prediction.\n",
    "        targets: Sequence of target tensors, one per head, holding one class\n",
    "            index per sample.\n",
    "\n",
    "    Returns:\n",
    "        NumPy float: share of samples whose prediction equals the target in\n",
    "        all heads at once.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # Rows are heads, columns are samples.\n",
    "        predicted = np.vstack([\n",
    "            head.topk(1, 1, True, True)[1].flatten().cpu().numpy()\n",
    "            for head in outputs\n",
    "        ])\n",
    "        expected = np.vstack([tgt.detach().cpu().numpy() for tgt in targets])\n",
    "\n",
    "        # A sample counts only if it is right in every row (head).\n",
    "        return (predicted == expected).all(axis=0).mean()\n",
    "    \n",
    "def train(train_loader, model, criterion, optimizer, epoch):\n",
    "    \"\"\"Run one optimisation pass over ``train_loader``.\n",
    "\n",
    "    Each batch is ``(images, target0, target1, ...)`` with one target tensor per\n",
    "    model head. The step loss is the mean of the per-head ``criterion`` values.\n",
    "\n",
    "    Args:\n",
    "        train_loader: Iterable of batches as described above.\n",
    "        model: Module returning one output tensor per head; set to train mode.\n",
    "        criterion: Callable ``(output, target) -> loss tensor``.\n",
    "        optimizer: Optimizer stepped once per batch.\n",
    "        epoch: Used only to label the progress bar.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the model returns a different number of outputs than the\n",
    "            batch has targets.\n",
    "    \"\"\"\n",
    "    model.train()\n",
    "    # Move batches to wherever the model's parameters live.\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    # Exactly one progress bar per call: a second bar that is created but never\n",
    "    # iterated stays open and clutters stderr a little more every epoch.\n",
    "    for batch in tqdm_notebook(train_loader, desc='epoch %s' % epoch):\n",
    "        images = batch[0].to(device, non_blocking=True)\n",
    "        targets = [tgt.to(device, non_blocking=True) for tgt in batch[1:]]\n",
    "\n",
    "        optimizer.zero_grad()\n",
    "\n",
    "        outputs = model(images)\n",
    "        if len(outputs) != len(targets):\n",
    "            raise ValueError('model returned %d outputs for %d targets'\n",
    "                             % (len(outputs), len(targets)))\n",
    "\n",
    "        head_losses = [criterion(out, tgt) for out, tgt in zip(outputs, targets)]\n",
    "        loss = sum(head_losses) / float(len(head_losses))\n",
    "\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "def validate(val_loader, model, criterion):\n",
    "    \"\"\"Evaluate ``model`` on ``val_loader``; print and return the mean metrics.\n",
    "\n",
    "    Each batch is ``(images, target0, target1, ...)`` with one target tensor per\n",
    "    model head.\n",
    "\n",
    "    Args:\n",
    "        val_loader: Iterable of batches as described above.\n",
    "        model: Module returning one output tensor per head; set to eval mode.\n",
    "        criterion: Callable ``(output, target) -> loss tensor``.\n",
    "\n",
    "    Returns:\n",
    "        ``(mean_acc, mean_loss)``: unweighted means over batches of the\n",
    "        all-heads-correct accuracy and of the summed per-head loss.\n",
    "\n",
    "    Raises:\n",
    "        ValueError: If the model returns a different number of outputs than the\n",
    "            batch has targets.\n",
    "    \"\"\"\n",
    "    model.eval()\n",
    "    # Move batches to wherever the model's parameters live.\n",
    "    device = next(model.parameters()).device\n",
    "\n",
    "    batch_acc = []\n",
    "    batch_loss = []\n",
    "\n",
    "    with torch.no_grad():\n",
    "        for batch in val_loader:\n",
    "            images = batch[0].to(device, non_blocking=True)\n",
    "            targets = [tgt.to(device, non_blocking=True) for tgt in batch[1:]]\n",
    "\n",
    "            outputs = model(images)\n",
    "            if len(outputs) != len(targets):\n",
    "                raise ValueError('model returned %d outputs for %d targets'\n",
    "                                 % (len(outputs), len(targets)))\n",
    "\n",
    "            # NOTE(review): this is the SUM of the per-head losses, whereas\n",
    "            # train() averages them, so the value printed here is on a larger\n",
    "            # scale than the training loss. Kept as is so the printed numbers\n",
    "            # stay comparable with earlier runs.\n",
    "            loss = sum(criterion(out, tgt) for out, tgt in zip(outputs, targets))\n",
    "\n",
    "            batch_acc.append(accuracy(list(outputs), targets))\n",
    "            batch_loss.append(loss.item())\n",
    "\n",
    "    # Plain mean over batches: a smaller final batch weighs as much as a full one.\n",
    "    mean_acc = np.mean(batch_acc)\n",
    "    mean_loss = np.mean(batch_loss)\n",
    "    print('VAL', mean_acc, mean_loss)\n",
    "    return mean_acc, mean_loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T18:10:41.413333Z",
     "start_time": "2019-06-15T17:46:57.788157Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\r",
      "  0%|          | 0/370 [00:00<?, ?it/s]"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "0eac38b1d96b4d0386b2e9a0860624d7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.8873212442871885 0.7798011993107042\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a7d70228151b443a9cc8055017567da8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9440365619932182 0.3888101142488028\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "69c9ad9f8d4b49fbbd66873010bc5892",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.968000884564352 0.2243058097205664\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "edea4f44cb98456983c9b242693cd4d9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9780554326993955 0.18222678168431708\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b5964db6ea9c44678086c7b457fe6ec3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9845864661654137 0.13463027490989157\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aa61fec85eb64d40b17d75679c93206a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9851982898422526 0.1544800761498903\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a8f4f574f91745a28a3737b6df325e9b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9883458646616541 0.12649611391029075\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e393c70fc95a468aa3c65d3be096bc79",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9842105263157894 0.13341617589130214\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f9fa0e5e8e4d4697a6c74b0d2bcca542",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9909774436090224 0.11875742517019573\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "89d3a8d19d0d4a5bbffea46578f5e4d3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9894736842105264 0.11885783802963008\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "12e70d18570340a2b1e5d24fa34ade22",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9919652071354857 0.1097700136076463\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "98e85c8e6a604ea787ed40398cfd767d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.988205808639245 0.11618428096469295\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "1557b00952ba4d2cb2024bea348c776a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "\n",
      "  0%|          | 0/370 [00:00<?, ?it/s]\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A\u001b[A"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "VAL 0.9908373875866134 0.10655048229780636\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "834c7a07097d48b09080a4e4f2bc8f21",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=370), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-6-a8d916200a69>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     31\u001b[0m \u001b[0mscheduler\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0moptim\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlr_scheduler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mStepLR\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_size\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgamma\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m0.85\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     32\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mepoch_idx\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m20\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 33\u001b[0;31m     \u001b[0mtrain\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0moptimizer\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mepoch_idx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     34\u001b[0m     \u001b[0mvalidate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mval_loader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcriterion\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     35\u001b[0m     \u001b[0mscheduler\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-5-8bbb0a9ab215>\u001b[0m in \u001b[0;36mtrain\u001b[0;34m(train_loader, model, criterion, optimizer, epoch)\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0miterator\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m     \u001b[0;32mfor\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget0\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget1\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget2\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget3\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget4\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget5\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget6\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget7\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget8\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mtarget9\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_loader\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m         \u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/_tqdm_notebook.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    219\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    220\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 221\u001b[0;31m             \u001b[0;32mfor\u001b[0m \u001b[0mobj\u001b[0m \u001b[0;32min\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtqdm_notebook\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__iter__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    222\u001b[0m                 \u001b[0;31m# return super(tqdm...) will not catch exception\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    223\u001b[0m                 \u001b[0;32myield\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/_tqdm.py\u001b[0m in \u001b[0;36m__iter__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1040\u001b[0m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mn\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1041\u001b[0m                         \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_lock\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1042\u001b[0;31m                             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdisplay\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1043\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1044\u001b[0m                         \u001b[0;31m# If no `miniters` was specified, adjust automatically\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/_tqdm.py\u001b[0m in \u001b[0;36mdisplay\u001b[0;34m(self, msg, pos)\u001b[0m\n\u001b[1;32m   1313\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mpos\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1314\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmoveto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpos\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1315\u001b[0;31m         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msp\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmsg\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mmsg\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1316\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mpos\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1317\u001b[0m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmoveto\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mpos\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/tqdm/_tqdm_notebook.py\u001b[0m in \u001b[0;36mprint_status\u001b[0;34m(s, close, bar_style, desc)\u001b[0m\n\u001b[1;32m    162\u001b[0m                 \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'||'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m''\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# remove inesthetical pipes\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    163\u001b[0m                 \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mescape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m  \u001b[0;31m# html escape special characters (like '?')\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 164\u001b[0;31m                 \u001b[0mptext\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    165\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    166\u001b[0m             \u001b[0;31m# Change bar style\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/traitlets/traitlets.py\u001b[0m in \u001b[0;36m__set__\u001b[0;34m(self, obj, value)\u001b[0m\n\u001b[1;32m    583\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mTraitError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'The \"%s\" trait is read-only.'\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    584\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 585\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    586\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    587\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_validate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/traitlets/traitlets.py\u001b[0m in \u001b[0;36mset\u001b[0;34m(self, obj, value)\u001b[0m\n\u001b[1;32m    572\u001b[0m             \u001b[0;31m# we explicitly compare silent to True just in case the equality\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    573\u001b[0m             \u001b[0;31m# comparison above returns something other than True/False\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 574\u001b[0;31m             \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_notify_trait\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mold_value\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnew_value\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    575\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    576\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__set__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/traitlets/traitlets.py\u001b[0m in \u001b[0;36m_notify_trait\u001b[0;34m(self, name, old_value, new_value)\u001b[0m\n\u001b[1;32m   1137\u001b[0m             \u001b[0mnew\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnew_value\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1138\u001b[0m             \u001b[0mowner\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1139\u001b[0;31m             \u001b[0mtype\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'change'\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1140\u001b[0m         ))\n\u001b[1;32m   1141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/ipywidgets/widgets/widget.py\u001b[0m in \u001b[0;36mnotify_change\u001b[0;34m(self, change)\u001b[0m\n\u001b[1;32m    597\u001b[0m             \u001b[0;32mif\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkeys\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_should_send_property\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mgetattr\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    598\u001b[0m                 \u001b[0;31m# Send new state to front-end\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 599\u001b[0;31m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    600\u001b[0m         \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mWidget\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnotify_change\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchange\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    601\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/ipywidgets/widgets/widget.py\u001b[0m in \u001b[0;36msend_state\u001b[0;34m(self, key)\u001b[0m\n\u001b[1;32m    482\u001b[0m             \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffer_paths\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_remove_buffers\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mstate\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    483\u001b[0m             \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m'method'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0;34m'update'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'state'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mstate\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'buffer_paths'\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mbuffer_paths\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 484\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_send\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbuffers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    485\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mget_state\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdrop_defaults\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/ipywidgets/widgets/widget.py\u001b[0m in \u001b[0;36m_send\u001b[0;34m(self, msg, buffers)\u001b[0m\n\u001b[1;32m    727\u001b[0m         \u001b[0;34m\"\"\"Sends a message to the model in the front-end.\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    728\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomm\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 729\u001b[0;31m             \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcomm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbuffers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    730\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    731\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m_repr_keys\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/ipykernel/comm/comm.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, data, metadata, buffers)\u001b[0m\n\u001b[1;32m    119\u001b[0m         \u001b[0;34m\"\"\"Send a message to the frontend-side version of this comm\"\"\"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    120\u001b[0m         self._publish_msg('comm_msg',\n\u001b[0;32m--> 121\u001b[0;31m             \u001b[0mdata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetadata\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmetadata\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbuffers\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    122\u001b[0m         )\n\u001b[1;32m    123\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/ipykernel/comm/comm.py\u001b[0m in \u001b[0;36m_publish_msg\u001b[0;34m(self, msg_type, data, metadata, buffers, **keys)\u001b[0m\n\u001b[1;32m     69\u001b[0m             \u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mkernel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_parent_header\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     70\u001b[0m             \u001b[0mident\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtopic\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 71\u001b[0;31m             \u001b[0mbuffers\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mbuffers\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     72\u001b[0m         )\n\u001b[1;32m     73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jupyter_client/session.py\u001b[0m in \u001b[0;36msend\u001b[0;34m(self, stream, msg_or_type, content, parent, ident, buffers, track, header, metadata)\u001b[0m\n\u001b[1;32m    735\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madapt_version\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    736\u001b[0m             \u001b[0mmsg\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0madapt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0madapt_version\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 737\u001b[0;31m         \u001b[0mto_send\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mserialize\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mident\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    738\u001b[0m         \u001b[0mto_send\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mextend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbuffers\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    739\u001b[0m         \u001b[0mlongest\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m[\u001b[0m \u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0ms\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mto_send\u001b[0m \u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jupyter_client/session.py\u001b[0m in \u001b[0;36mserialize\u001b[0;34m(self, msg, ident)\u001b[0m\n\u001b[1;32m    633\u001b[0m             \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"Content incorrect type: %s\"\u001b[0m\u001b[0;34m%\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcontent\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    634\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 635\u001b[0;31m         real_message = [self.pack(msg['header']),\n\u001b[0m\u001b[1;32m    636\u001b[0m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'parent_header'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    637\u001b[0m                         \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpack\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmsg\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'metadata'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jupyter_client/session.py\u001b[0m in \u001b[0;36m<lambda>\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m    101\u001b[0m \u001b[0;31m# disallow nan, because it's not actually valid JSON\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    102\u001b[0m json_packer = lambda obj: jsonapi.dumps(obj, default=date_default,\n\u001b[0;32m--> 103\u001b[0;31m     \u001b[0mensure_ascii\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mallow_nan\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    104\u001b[0m )\n\u001b[1;32m    105\u001b[0m \u001b[0mjson_unpacker\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0ms\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mjsonapi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mloads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/zmq/utils/jsonapi.py\u001b[0m in \u001b[0;36mdumps\u001b[0;34m(o, **kwargs)\u001b[0m\n\u001b[1;32m     38\u001b[0m         \u001b[0mkwargs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'separators'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m','\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m':'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 40\u001b[0;31m     \u001b[0ms\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjsonmod\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdumps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     41\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     42\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0ms\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0municode\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/lib/python3/dist-packages/simplejson/__init__.py\u001b[0m in \u001b[0;36mdumps\u001b[0;34m(obj, skipkeys, ensure_ascii, check_circular, allow_nan, cls, indent, separators, encoding, default, use_decimal, namedtuple_as_object, tuple_as_array, bigint_as_string, sort_keys, item_sort_key, for_json, ignore_nan, int_as_string_bitcount, iterable_as_array, **kw)\u001b[0m\n\u001b[1;32m    397\u001b[0m         \u001b[0mignore_nan\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mignore_nan\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    398\u001b[0m         \u001b[0mint_as_string_bitcount\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mint_as_string_bitcount\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 399\u001b[0;31m         **kw).encode(obj)\n\u001b[0m\u001b[1;32m    400\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    401\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/lib/python3/dist-packages/simplejson/encoder.py\u001b[0m in \u001b[0;36mencode\u001b[0;34m(self, o)\u001b[0m\n\u001b[1;32m    289\u001b[0m         \u001b[0;31m# exceptions aren't as detailed.  The list call should be roughly\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    290\u001b[0m         \u001b[0;31m# equivalent to the PySequence_Fast that ''.join() would do.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 291\u001b[0;31m         \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0m_one_shot\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mTrue\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    292\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunks\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0mlist\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtuple\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    293\u001b[0m             \u001b[0mchunks\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mchunks\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/lib/python3/dist-packages/simplejson/encoder.py\u001b[0m in \u001b[0;36miterencode\u001b[0;34m(self, o, _one_shot)\u001b[0m\n\u001b[1;32m    371\u001b[0m                 self.iterable_as_array, Decimal=decimal.Decimal)\n\u001b[1;32m    372\u001b[0m         \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 373\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0m_iterencode\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mo\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    374\u001b[0m         \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    375\u001b[0m             \u001b[0mkey_memo\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclear\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/usr/local/lib/python3.6/dist-packages/jupyter_client/jsonutil.py\u001b[0m in \u001b[0;36mdate_default\u001b[0;34m(obj)\u001b[0m\n\u001b[1;32m     89\u001b[0m     \u001b[0;32mif\u001b[0m \u001b[0misinstance\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdatetime\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     90\u001b[0m         \u001b[0mobj\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_ensure_tzinfo\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 91\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0misoformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreplace\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'+00:00'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m'Z'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     92\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     93\u001b[0m         \u001b[0;32mraise\u001b[0m \u001b[0mTypeError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"%r is not JSON serializable\"\u001b[0m \u001b[0;34m%\u001b[0m \u001b[0mobj\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Deterministic preprocessing shared by the training and validation loaders:\n",
    "# resize to 80x320, convert to a tensor, then per-channel mean/std normalisation.\n",
    "base_steps = [\n",
    "    transforms.Resize((80, 320)),\n",
    "    transforms.ToTensor(),\n",
    "    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),\n",
    "]\n",
    "\n",
    "# Training split: light random affine / colour augmentation ahead of the shared steps.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    QRDataset(\n",
    "        data_json['fold11_train'],\n",
    "        transforms.Compose([\n",
    "            transforms.RandomAffine(5),\n",
    "            transforms.ColorJitter(hue=.05, saturation=.05),\n",
    "        ] + base_steps),\n",
    "    ),\n",
    "    batch_size=100, shuffle=True, num_workers=10, pin_memory=True,\n",
    ")\n",
    "\n",
    "# Validation split: no random augmentation and no shuffling, so the reported\n",
    "# metrics are computed on the same inputs every epoch and stay comparable.\n",
    "val_loader = torch.utils.data.DataLoader(\n",
    "    QRDataset(data_json['fold11_test'], transforms.Compose(base_steps)),\n",
    "    batch_size=70, shuffle=False, num_workers=10, pin_memory=True,\n",
    ")\n",
    "\n",
    "model = RMB_Net()\n",
    "model = model.cuda()\n",
    "# model = nn.DataParallel(model).cuda()\n",
    "criterion = nn.CrossEntropyLoss().cuda()\n",
    "optimizer = torch.optim.Adam(model.parameters(), 0.001)\n",
    "\n",
    "# StepLR with step_size=1 multiplies the learning rate by gamma after every epoch.\n",
    "scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.85)\n",
    "for epoch_idx in range(20):\n",
    "    train(train_loader, model, criterion, optimizer, epoch_idx)\n",
    "    validate(val_loader, model, criterion)\n",
    "    # Step the scheduler once per epoch, after the optimizer updates made in train().\n",
    "    scheduler.step()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T16:40:34.998987Z",
     "start_time": "2019-06-15T16:40:34.983990Z"
    },
    "scrolled": true
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "dict_keys(['abc', 'train', 'test', 'pb', 'fold0_train', 'fold0_test', 'fold1_train', 'fold1_test', 'fold2_train', 'fold2_test', 'fold3_train', 'fold3_test', 'fold4_train', 'fold4_test', 'fold5_train', 'fold5_test', 'fold6_train', 'fold6_test', 'fold7_train', 'fold7_test', 'fold8_train', 'fold8_test', 'fold9_train', 'fold9_test', 'fold10_train', 'fold10_test', 'fold11_train', 'fold11_test', 'fold12_train', 'fold12_test', 'fold13_train', 'fold13_test', 'fold14_train', 'fold14_test'])"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "data_json.keys()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T15:22:21.939964Z",
     "start_time": "2019-06-15T15:22:21.474987Z"
    }
   },
   "outputs": [],
   "source": [
    "model = RMB_Net()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T16:54:47.004845Z",
     "start_time": "2019-06-15T16:54:46.998377Z"
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "# Inspection copy of the fold-11 training split: colour jitter and resize only.\n",
    "# RandomAffine, ToTensor and Normalize are left out of this pipeline on purpose\n",
    "# (they were commented out), so no tensor conversion or normalisation is applied here.\n",
    "data = QRDataset(\n",
    "    data_json['fold11_train'],\n",
    "    transform=transforms.Compose([\n",
    "        transforms.ColorJitter(hue=.05, saturation=.05),\n",
    "        transforms.Resize((80, 320)),\n",
    "    ]),\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 70,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:12:55.953636Z",
     "start_time": "2019-06-15T17:12:55.943774Z"
    }
   },
   "outputs": [],
   "source": [
    "def accuracy(outputs, targets):\n",
    "    \"\"\"Per-sample fraction of heads whose top-1 prediction equals the target.\n",
    "\n",
    "    Args:\n",
    "        outputs: sequence of score tensors, one per head; classes are ranked along dim 1.\n",
    "        targets: sequence of label tensors, one per head, aligned with `outputs`.\n",
    "\n",
    "    Returns:\n",
    "        NumPy array holding, for each sample, the mean over heads of (prediction == target).\n",
    "    \"\"\"\n",
    "    with torch.no_grad():\n",
    "        # One row per head: index of the highest-scoring class for every sample.\n",
    "        predicted = np.vstack([head.argmax(dim=1).cpu().numpy() for head in outputs])\n",
    "        expected = np.vstack([target.detach().cpu().numpy() for target in targets])\n",
    "        # Compare the stacked arrays built above; the previous `a == b` referred to\n",
    "        # names that are not defined inside this function.\n",
    "        return (predicted == expected).mean(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 73,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:13:20.914720Z",
     "start_time": "2019-06-15T17:13:20.908189Z"
    }
   },
   "outputs": [],
   "source": [
    "a, b = accuracy([torch.rand(20, 36), torch.rand(20, 36)], [torch.rand(20), torch.rand(20)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:14:03.845776Z",
     "start_time": "2019-06-15T17:14:03.838990Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "(a == b).sum(0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-15T17:13:52.770844Z",
     "start_time": "2019-06-15T17:13:52.764952Z"
    }
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(2, 20)"
      ]
     },
     "execution_count": 78,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "a.shape"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.7"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
